{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "87e8360b-8d08-44bc-9333-79ba949afe8c",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "# Accelerating Hugging Face Gemma Inference with Transformer Engine"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2da33092-eef5-46a4-b222-0188cc6e5079",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## Introduction\n",
    "\n",
    "Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"./media/generation_animation.gif\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\" >\n",
    "<figcaption>\n",
    "Animation 1: Hugging Face Gemma model token generation.\n",
    "</figcaption>\n",
    "</figure>\n",
    "\n",
    "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n",
    "\n",
    "In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
    "\n",
    "This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n",
    "\n",
    "### 1. From vanilla KV-caching to Paged Attention for inference in Transformer Engine\n",
    "\n",
    "The original [Attention mechanism](https://arxiv.org/pdf/1706.03762) ushered in an era of Large Language Models, but the same attention mechanism, if used for deployment in inference scenarios, can be computationally wasteful. It is primarily due to a lot of redundant computation that happens in attention when the Transformer models are used autoregressively to compute the next token. Several tutorials on the internet explain in detail how KV Caching helps to reduce that redundant computation, e.g., [tutorial 1](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms), [tutorial 2](https://medium.com/@joaolages/kv-caching-explained-276520203249), etc.\n",
    "\n",
    "\n",
    "Further, even though the performance benefit of KV Cache is immense, it comes at the cost of increased memory usage, which becomes a problem especially for longer context lengths. The major problems are: \n",
    "\n",
    "1. Internal fragmentation\n",
    "2. External Fragmentation\n",
    "\n",
    "More information can be found in the [Paged Attention](https://arxiv.org/pdf/2309.06180) paper. The authors solve the above problems by treating the KV cache as a virtual memory with the actual physical blocks being much smaller than the overall cache size. This makes it easier to swap them in and out of GPU HBM as needed - very similar to how Operating Systems implement virtual memory to swap the individual pages in and out of the CPU RAM.\n",
    "\n",
    "\n",
    "Transformer Engine allows users to use both \"Non-paged\" and \"Paged\" forms of KV Caching, and the results in this tutorial are posted for both use cases.\n",
    "\n",
    "\n",
    "### 2. CUDA Graphs API\n",
    "\n",
    "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to finish processing and then launch the kernels, which can lead to significant overhead. CUDA Graphs can address this issue. When such blocks of computation are executed repeatedly, CUDA Graphs allow us to record and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where multiple \"Transformer/Decoder Layers\" are run for every token that needs to be generated.\n",
    "\n",
    "One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n",
    "\n",
    "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about the CUDA graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"./media/graphs.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\" >\n",
    "<figcaption>\n",
    "Figure 1: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.\n",
    "</figcaption>\n",
    "</figure>\n",
    "\n",
    "### 3. FP8 Scaling Factors Calibration\n",
    "\n",
    "This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n",
    "\n",
    "If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n",
    "\n",
    "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n",
    "\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"./media/calibration.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\">\n",
    "<figcaption>\n",
    "Figure 2:\n",
    "Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n",
    "</figcaption>\n",
    "</figure>\n",
    "\n",
    "### 4. FP8 Model Weights\n",
    "\n",
    "The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n",
    "\n",
    "The Transformer Engine includes a wrapper `quantized_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"./media/fp8_model_init.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\">\n",
    "<figcaption>\n",
    "Figure 3: Model under <b>autocast()</b> stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using <b>quantized_model_init()</b> results in storing model weights in FP8 by default, which can help with these potential issues.\n",
    "</figcaption>\n",
    "</figure>\n",
    "\n",
    "### Benchmarking\n",
    "\n",
    "We'll evaluate the generation time across one benchmark: token generation with context/prefill phase max sequence length = 20, batch size = 64, and number of generated tokens = 492 on random texts with random lengths. This is a purely synthetic benchmark.\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note</b>\n",
    "    \n",
    "This tutorial focuses on showcasing the mentioned features of the Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT-LLM](https://docs.nvidia.com/tensorrt-llm/index.html), which is optimized for inference tasks and should be considered for such use cases.\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b18f91a9",
   "metadata": {},
   "source": [
    "## Dependencies for this tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5201d77",
   "metadata": {},
   "source": [
    "The following files and media are necessary to effectively run this tutorial:\n",
    "\n",
    "1. `te_gemma.py`\n",
    "    - This file contains the code to load a Hugging Face Gemma checkpoint weights in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. Further, it contains necessary abstractions like a subclass of `GemmaForCausalLM` - `TEGemmaForCausalLM` that is used for generation with Transformer Engine's `TransformerLayer`, CUDA Graphs, and FP8 calibration for generation in FP8 precision.\n",
    "2. `te_gemma_loading_weights.py`\n",
    "    - This file contains the logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n",
    "3. `utils.py`\n",
    "    - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training, and other miscellaneous tasks like restarting the Jupyter notebook from within the cell. \n",
    "4. `requirements.txt`\n",
    "    - This file contains the necessary Python packages for this tutorial.\n",
    "5. `media/`\n",
    "    - This directory contains the images and other artefacts used in this tutorial."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36767694-a1c5-4a00-a075-7addc55d8307",
   "metadata": {},
   "source": [
    "### Setup and checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1de3351b-fa21-4b95-bb9e-d01ac8bb7edf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment and run this cell when running the tutorial for the first time\n",
    "# %pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c756ebbd-24c9-4a54-a381-e7c02c555206",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import torch\n",
    "cudnn_version = torch.backends.cudnn.version()\n",
    "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8dfabbf",
   "metadata": {},
   "source": [
    "## [Baseline] Running Hugging Face generation with Gemma model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59560bff",
   "metadata": {},
   "source": [
    "HuggingFace Transformers library offers generation API. \n",
    "HuggingFace generation for the Gemma model will be used as a baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2803e0ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================== Generation example 1 ==============================\n",
      "Prompt: \"Here are the two facts about GPUs:\"\n",
      "Generated text: \"\n",
      "\n",
      "1. They are very good at doing a lot of the same thing at the same time.\n",
      "2. They are very bad at doing different things at the same time.\n",
      "\n",
      "The first fact is why GPUs are so good at graphics. The\"\n",
      "============================== Generation example 2 ==============================\n",
      "Prompt: \"Some facts about NVIDIA:\"\n",
      "Generated text: \"\n",
      "\n",
      "* NVIDIA is a global technology company that designs and builds advanced computer graphics and video processing chips for the PC and video game console markets.\n",
      "* The company is a leading provider of graphics processing units (GPUs) for the PC and video game\"\n",
      "\n",
      "================================================================================\n",
      "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
      "Time: 46.60 s.\n"
     ]
    }
   ],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
    "restart_jupyter_notebook()\n",
    "\n",
    "from utils import *\n",
    "\n",
    "# Provide Huggingface Access Token\n",
    "run_config.hf_access_token = \"\"\n",
    "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
    "run_config.model_name = \"google/gemma-7b\"\n",
    "\n",
    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
    "run_config.weights_cache_dir = \"\"\n",
    "\n",
    "# Set specific hyperparameters\n",
    "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
    "run_config.batch_size = 64\n",
    "run_config.max_seq_length = 512\n",
    "\n",
    "model = init_baseline_model(run_config)\n",
    "\n",
    "print_sample_of_generated_texts(model, run_config)\n",
    "benchmark_generation(model, run_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3698dc6",
   "metadata": {},
   "source": [
    "Let's put this time into the table for later comparison.\n",
    "\n",
    "| Models                                                      | Time | Speedup |  \n",
    "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n",
    "| HF (baseline)                                               | 46.6 s      | -                         |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bb40f45",
   "metadata": {},
   "source": [
    "## [Optimization 1] Accelerating generation with Transformer Engine "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "263b40f2",
   "metadata": {},
   "source": [
    "Similar to the [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) finetuning tutorial, a `GemmaDecoderLayer` is substituted by a tuned `TransformerLayer` from the Transformer Engine library. Let's run it and compare the time with the baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9dceef93",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================== Generation example 1 ==============================\n",
      "Prompt: \"Here are the two facts about GPUs:\"\n",
      "Generated text: \"\n",
      "\n",
      "1. They are very good at doing a lot of the same thing at the same time.\n",
      "2. They are very bad at doing different things at the same time.\n",
      "\n",
      "The first fact is why they are so good at graphics. The second\"\n",
      "============================== Generation example 2 ==============================\n",
      "Prompt: \"Some facts about NVIDIA:\"\n",
      "Generated text: \"\n",
      "\n",
      "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n",
      "* NVIDIA is the world leader in AI computing.\n",
      "* NVIDIA is the world leader in graphics processing units (GP\"\n",
      "\n",
      "================================================================================\n",
      "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
      "Time: 12.25 s.\n"
     ]
    }
   ],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
    "restart_jupyter_notebook()\n",
    "\n",
    "from utils import *\n",
    "\n",
    "# Provide Huggingface Access Token\n",
    "run_config.hf_access_token = \"\"\n",
    "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
    "run_config.model_name = \"google/gemma-7b\"\n",
    "\n",
    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
    "run_config.weights_cache_dir = \"\"\n",
    "\n",
    "# Set specific hyperparameters\n",
    "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
    "run_config.batch_size = 64\n",
    "run_config.max_seq_length = 512\n",
    "run_config.is_paged = False  # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
    "\n",
    "model = init_te_gemma_model(run_config)\n",
    "\n",
    "print_sample_of_generated_texts(model, run_config)\n",
    "benchmark_generation(model, run_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5d40836",
   "metadata": {},
   "source": [
    "With just using Transformer Engine with default (non-paged) KV cache, a speedup of **3.8x** was obtained. Neat!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "006d18e8",
   "metadata": {},
   "source": [
    "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n",
    "|---|---|---|---|---|\n",
    "| HF (baseline) | 46.6 s | - | - | - |\n",
    "| TE (subsitution of `GemmaDecoderLayer` with `te.TransformerLayer`) | 12.25 s | 3.8x | 12.24 s | 3.8x |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21a89d9c",
   "metadata": {},
   "source": [
    "## [Optimization 2] More acceleration with CUDA Graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2d53e7b",
   "metadata": {},
   "source": [
    "Transformer Engine includes a function `transformer_engine.pytorch.make_graphed_callables`, which behaves similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from [te_gemma.py](./te_gemma.py) from class `TEGemmaForCausalLMCudaGraphs`:\n",
    "```python\n",
    "    def __init__(self, config : GemmaConfig):\n",
    "        \"\"\"\n",
    "        Here \"the trick\" happens. `_model_context_phase` and\n",
    "        `_model_generation_phase` from TEGemmaForCausalLM are replaced with\n",
    "        their recorded version. Once the graphs are recorded, they can be\n",
    "        replayed with minimal usage of CPU and that leads to speedup.\n",
    "        \"\"\"\n",
    "        (...)\n",
    "        # Record the graph for context/prefill phase.\n",
    "        self._model_context_phase = \n",
    "            self.record_graph(self._model_context_phase, self.hidden_states_buffer)\n",
    "\n",
    "        (...)        \n",
    "        # Record the graph for generation phase.\n",
    "        self._model_generation_phase = \n",
    "            self.record_graph(self._model_generation_phase, self.generation_buffer)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def record_graph(self, function, input_tensor):\n",
    "        \"\"\"\n",
    "        Records the graph for the given function. The function is invoked on\n",
    "        argument (self.hidden_states,) and all kernels are recorded.\n",
    "        It then returns the captured callable, which can be run later while\n",
    "        minimizing CPU usage.\n",
    "        \"\"\"\n",
    "        fp8_recipe = get_default_fp8_recipe()\n",
    "\n",
    "        # We need both autocasts: FP8 for operations that can run in lower\n",
    "        # precision and BF16 for those that cannot.\n",
    "        with autocast(\"cuda\", dtype=torch.bfloat16, cache_enabled=False):\n",
    "            graphed_function = te.pytorch.make_graphed_callables(\n",
    "                function,\n",
    "                (input_tensor,),\n",
    "                enabled=self.config.fp8,\n",
    "                recipe=fp8_recipe,\n",
    "                allow_unused_input=True,\n",
    "                num_warmup_iters=5,\n",
    "                sample_kwargs=sample_kwargs,\n",
    "            )\n",
    "        return graphed_function\n",
    "```\n",
    "\n",
    "It is strongly recommended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs.\n",
    "\n",
    "*Note the usage of static buffers and corresponding configuration in the following cell, which is necessary for CUDA Graphs to function.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "31a3a8a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================== Generation example 1 ==============================\n",
      "Prompt: \"Here are the two facts about GPUs:\"\n",
      "Generated text: \"\n",
      "\n",
      "1. They are very good at doing a lot of the same thing at the same time.\n",
      "2. They are very bad at doing different things at the same time.\n",
      "\n",
      "The first fact is why they are so good at graphics. The second\"\n",
      "============================== Generation example 2 ==============================\n",
      "Prompt: \"Some facts about NVIDIA:\"\n",
      "Generated text: \"\n",
      "\n",
      "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n",
      "* NVIDIA is the world leader in AI computing.\n",
      "* NVIDIA is the world leader in graphics processing units (GP\"\n",
      "\n",
      "================================================================================\n",
      "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
      "Time: 6.39 s.\n"
     ]
    }
   ],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
    "restart_jupyter_notebook()\n",
    "\n",
    "from utils import *\n",
    "\n",
    "# Provide Huggingface Access Token\n",
    "run_config.hf_access_token = \"\"\n",
    "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
    "run_config.model_name = \"google/gemma-7b\"\n",
    "\n",
    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
    "run_config.weights_cache_dir = \"\"\n",
    "\n",
    "# Set specific hyperparameters\n",
    "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
    "run_config.max_seq_length = 512\n",
    "run_config.batch_size = 64\n",
    "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
    "\n",
    "# It is necessary to preallocate a static buffer.\n",
    "# CUDA graphs require static input tensors for every kernel.\n",
    "# This approach may result in a slight increase in memory consumption;\n",
    "# however, the substantial speedup achieved makes it worthwhile.\n",
    "run_config.generation_cuda_graphs = True\n",
    "run_config.cuda_graphs_static_batch_size = 64\n",
    "run_config.cuda_graphs_static_max_seq_len = 512\n",
    "run_config.cuda_graphs_static_max_context_len = 512\n",
    "\n",
    "model = init_te_gemma_model(run_config)\n",
    "\n",
    "print_sample_of_generated_texts(model, run_config)\n",
    "benchmark_generation(model, run_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53bb430f",
   "metadata": {},
   "source": [
    "A speed up of **7.2x** was obtained by using CUDA Graphs with TE's `TransformerLayer`.\n",
    "\n",
    "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n",
    "|---|---|---|---|---|\n",
    "| HF (baseline) | 46.6 s | - | - | - |\n",
    "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n",
    "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a11b75c",
   "metadata": {},
   "source": [
    "Let's profile the code from one of the cells above, which runs generation with the Gemma model, and examine the resulting traces in [NVIDIA Nsight Systems](https://developer.nvidia.com/nsight-systems) to understand the performance characteristics and sources of speedup. A few things to recap:\n",
    "\n",
    "1. For the TE Gemma model implementation, `model.generate()` internally calls `model_context_phase` and `model_generation_phase`.\n",
    "2. They are just wrappers around the Gemma model's layers, and they are graphed separately when CUDA graphs are enabled.\n",
    "3. So, for each token generated (after the first token), a single invocation of `model_generation_phase` happens as a complete CUDA graph. \n",
    "4. The following illustration zooms in on a single `TransformerLayer` layer forward pass (within the larger `model_generation_phase` graphed callable) for clarity.\n",
    "\n",
    "(For details, refer to the implementation in [te_gemma.py](./te_gemma.py))\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"./media/transformer_cuda_graphed.png\" width=\"80%\" \">\n",
    "<figcaption>\n",
    "    \n",
    "Figure 4: (Without CUDA graphs) Blue blobs in the top figure are GPU kernels, and whitespace b/w those indicates that GPUs are idle waiting for the CPU to finish processing and then launch kernels. (With CUDA graphs) The whitespace gets virtually eliminated because all the GPU kernels are bundled into a single highly optimized unit of work with no CPU time in between. (Note that for reference, the kernels are mapped across both cases, and the sizes of those kernels only seem different because of the presence of large voids in the former case, but the sizes are actually the same.)\n",
    "</figcaption>\n",
    "</figure>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6b171a0",
   "metadata": {},
   "source": [
    "## [Optimization 3] Even more acceleration with FP8 precision "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a80288b",
   "metadata": {},
   "source": [
    "### Calibrating FP8 scaling factors for correctness\n",
    "\n",
    "Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n",
    "\n",
    "1. Model weight tensors\n",
    "2. Input tensors\n",
    "\n",
    "If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n",
    "\n",
    "To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n",
    "\n",
    "*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n",
    " \n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"./media/calibration_1_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
    "<figcaption>\n",
    "    Figure 5: The default FP8 scaling factors are incorrect, and so the BF16 to FP8 conversion, as is, can lead to numerical errors. Calibration allows for collecting statistics/metadata about the input and weight tensors in higher precision during the forward pass.\n",
    "</figcaption>\n",
    "</figure>\n",
    "\n",
    "\n",
    "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aecee0e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
    "restart_jupyter_notebook()\n",
    "\n",
    "import transformer_engine.pytorch as te\n",
    "from utils import *\n",
    "\n",
    "# Provide Huggingface Access Token\n",
    "run_config.hf_access_token = \"\"\n",
    "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
    "run_config.model_name = \"google/gemma-7b\"\n",
    "\n",
    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
    "run_config.weights_cache_dir = \"\"\n",
    "\n",
    "run_config.fuse_qkv_params = True\n",
    "model = init_te_gemma_model(run_config)\n",
    "\n",
    "# Calibration\n",
    "with te.autocast(enabled=False, calibrating=True), torch.autocast(\n",
    "    device_type=\"cuda\", dtype=torch.bfloat16\n",
    "):\n",
    "    model.train()\n",
    "    run_forward_pass(model, run_config, num_iters=64)\n",
    "\n",
    "# Compute scale_fwd with enabled fp8 autocast\n",
    "with te.autocast(enabled=True), torch.autocast(\n",
    "    device_type=\"cuda\", dtype=torch.bfloat16\n",
    "):\n",
    "    run_forward_pass(model, run_config, 1)\n",
    "\n",
    "# Some parameters are in pointing to the same tensors, double save is avoided here.\n",
    "dict_to_save = {\n",
    "    k: v\n",
    "    for k, v in model.state_dict().items()\n",
    "    if (\"_context_phase\" not in k and \"_generation_phase\" not in k)\n",
    "}\n",
    "torch.save(\n",
    "    dict_to_save, \"calibrated_weights.pth\"\n",
    ")  # <-- Add path to save calibrated weights."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6dcd135",
   "metadata": {},
   "source": [
    "### Generation with better FP8 scaling factors\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"./media/calibration_2_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
    "<figcaption>\n",
    "    Figure 6: After the calibration process, FP8 scaling factors are correct and prevent numerical errors.\n",
    "</figcaption>\n",
    "</figure>\n",
    "\n",
    "Now that the calibration has produced correct scaling factors, FP8 inference is ready to be run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a913f54d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================== Generation example 1 ==============================\n",
      "Prompt: \"Here are the two facts about GPUs:\"\n",
      "Generated text: \"\n",
      "\n",
      "1. They are very good at doing the same thing over and over again.\n",
      "2. They are very bad at doing different things at the same time.\n",
      "\n",
      "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n",
      "============================== Generation example 2 ==============================\n",
      "Prompt: \"Some facts about NVIDIA:\"\n",
      "Generated text: \"\n",
      "\n",
      "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n",
      "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n",
      "* NVIDIA is a key player\"\n",
      "\n",
      "================================================================================\n",
      "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
      "Time: 8.73 s.\n"
     ]
    }
   ],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
    "restart_jupyter_notebook()\n",
    "\n",
    "from utils import *\n",
    "\n",
    "# Provide Huggingface Access Token\n",
    "run_config.hf_access_token = \"\"\n",
    "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
    "run_config.model_name = \"google/gemma-7b\"\n",
    "\n",
    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
    "run_config.weights_cache_dir = \"\"\n",
    "\n",
    "# Set specific hyperparameters\n",
    "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
    "run_config.fuse_qkv_params = True  # This is needed by the last improvement.\n",
    "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
    "\n",
    "# CUDA Graphs related config\n",
    "run_config.generation_cuda_graphs = True\n",
    "run_config.cuda_graphs_static_batch_size = 64\n",
    "run_config.cuda_graphs_static_max_seq_len = 512\n",
    "run_config.cuda_graphs_static_max_context_len = 512\n",
    "\n",
    "# Enable FP8\n",
    "run_config.fp8 = True\n",
    "# Calibrated fp8 weights are loaded directly from the file.\n",
    "run_config.fp8_model_weights_filename = (\n",
    "    \"calibrated_weights.pth\"  # <-- Add calibrated weights location here.\n",
    ")\n",
    "\n",
    "model = init_te_gemma_model(run_config)\n",
    "\n",
    "print_sample_of_generated_texts(model, run_config)\n",
    "benchmark_generation(model, run_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cdbb56c",
   "metadata": {},
   "source": [
    "One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n",
    "\n",
    "### Use of FP8-only model weights\n",
    "\n",
    "Running the model in FP8 precision does not imply that the weights are stored in FP8. By default, they are stored in higher precision and are cast to FP8, using saved scaling factors before GEMM operations (matrix multiplications).\n",
    "\n",
    "This approach is appropriate during training since gradients during the backward pass are produced in higher precision, and therefore, having higher precision copies of model weights helps, as they have enough dynamic range to encompass incoming information from the gradients. During the forward pass, the higher precision model weights and the batch inputs are cast to FP8, and the GEMMs occur in FP8 precision, which helps save training time overall if the time saved from running GEMM in FP8 precision (than in higher precision) is more than the extra time spent during the cast operation.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"./media/fp8_model_init_1_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
    "<figcaption>\n",
    "    Figure 7: Running the model at higher precision involves only one operation - GEMM. However, when the model operates in FP8, it requires casting inputs to the GEMM - namely, model weights and batch inputs from higher precision to FP8, which involves extra kernels in addition to the low-precision GEMM kernel.\n",
    "</figcaption>\n",
    "</figure>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "626aefa1-d5c4-4d8f-88d9-7d7943afde0d",
   "metadata": {},
   "source": [
    "However, things change during inference. Since the weights need no update and remain frozen, higher precision copies of weights could be avoided completely. It is possible to cast the higher precision weights only once to FP8 precision while initializing the model with appropriate scaling factors and then use those FP8-only copies of weights during the entirety of token generation. This provides two-fold benefits:\n",
    "\n",
    "1. Lower memory usage - since the model weights are stored in FP8 precision only (compared to training, where both BF16 and FP8 copies end up being present in the memory during peak usage).\n",
    "2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n",
    "\n",
    "\n",
    "Transformer Engine supports maintaining FP8-only weights with the `quantized_model_init` context manager. Let's see a small example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4562ee82-8c95-4736-8815-cd386078a485",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Memory required for 16384x16384 linear layer: \n",
      "FP32 - 1024.0 MB, \n",
      "BF16 - 512.0 MB, \n",
      "FP8 - 256.0 MB, \n",
      "\n",
      "Actual GPU memory usage with a TE FP32 linear layer: 1024.06 MB\n",
      "Actual GPU memory usage with a TE BF16 linear layer: 512.03 MB\n",
      "Actual GPU memory usage with a TE FP8 linear layer: 256.08 MB\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import transformer_engine.pytorch as te\n",
    "\n",
    "H = 2**14\n",
    "D = 2**14\n",
    "print(f\"Memory required for {H}x{D} linear layer: \\n\"\n",
    "      f\"FP32 - {H*D*4/1024**2} MB, \\n\"\n",
    "      f\"BF16 - {H*D*2/1024**2} MB, \\n\"\n",
    "      f\"FP8 - {H*D*1/1024**2} MB, \\n\")\n",
    "\n",
    "linear_fp32 = te.Linear(H, D, params_dtype=torch.float32) \n",
    "print(f\"Actual GPU memory usage with a TE FP32 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
    "del linear_fp32\n",
    "\n",
    "linear_bf16 = te.Linear(H, D, params_dtype=torch.bfloat16)\n",
    "print(f\"Actual GPU memory usage with a TE BF16 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
    "del linear_bf16\n",
    "\n",
    "# Initialize model weights in FP8 precision\n",
    "with torch.no_grad(), te.quantized_model_init(enabled=True):\n",
    "    linear_fp8 = te.Linear(H, D)\n",
    "print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
    "del linear_fp8"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a26aba9-f3ba-42c4-b4c3-9e845502ae1b",
   "metadata": {},
   "source": [
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"./media/fp8_model_init_2_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
    "<figcaption>\n",
    "    Figure 8: Using quantized_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n",
    "</figcaption>\n",
    "</figure>\n",
    "\n",
    "Let's run the code with `quantized_model_init`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "96264b9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================== Generation example 1 ==============================\n",
      "Prompt: \"Here are the two facts about GPUs:\"\n",
      "Generated text: \"\n",
      "\n",
      "1. They are very good at doing the same thing over and over again.\n",
      "2. They are very bad at doing different things at the same time.\n",
      "\n",
      "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n",
      "============================== Generation example 2 ==============================\n",
      "Prompt: \"Some facts about NVIDIA:\"\n",
      "Generated text: \"\n",
      "\n",
      "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n",
      "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n",
      "* NVIDIA is a key player\"\n",
      "\n",
      "================================================================================\n",
      "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
      "Time: 4.99 s.\n"
     ]
    }
   ],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
    "restart_jupyter_notebook()\n",
    "\n",
    "# Import necessary packages and methods\n",
    "from utils import *\n",
    "\n",
    "# Provide Huggingface Access Token\n",
    "run_config.hf_access_token = \"\"\n",
    "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
    "run_config.model_name = \"google/gemma-7b\"\n",
    "\n",
    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
    "run_config.weights_cache_dir = \"\"\n",
    "\n",
    "# Set specific hyperparameters\n",
    "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
    "run_config.fuse_qkv_params = True  # This is needed by the last improvement.\n",
    "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
    "\n",
    "# CUDA Graphs related config\n",
    "run_config.generation_cuda_graphs = True\n",
    "run_config.cuda_graphs_static_batch_size = 64\n",
    "run_config.cuda_graphs_static_max_seq_len = 512\n",
    "run_config.cuda_graphs_static_max_context_len = 512\n",
    "\n",
    "# Enable FP8 math and FP8 model weights\n",
    "run_config.fp8 = True\n",
    "run_config.quantized_model_init = True  # This will result in storing only fp8 weights.\n",
    "run_config.fp8_model_weights_filename = (\n",
    "    \"calibrated_weights.pth\"  # <-- Add calibrated weights location here.\n",
    ")\n",
    "\n",
    "model = init_te_gemma_model(run_config)\n",
    "\n",
    "print_sample_of_generated_texts(model, run_config)\n",
    "benchmark_generation(model, run_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e30ca5a",
   "metadata": {},
   "source": [
    "The final speedup is **9.3x**. \n",
    "\n",
    "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n",
    "|---|---|---|---|---|\n",
    "| HF (baseline) | 46.6 s | - | - | - |\n",
    "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n",
    "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n",
    "| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `quantized_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6e87275",
   "metadata": {},
   "source": [
    "## Conclusions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bb2452d",
   "metadata": {},
   "source": [
    "This tutorial focuses primarily on making the token generation faster with an off-the-shelf model downloaded from Hugging Face using the following features of the Transformer Engine:\n",
    "\n",
    "1. Support for KV Caching (both non-paged and paged),\n",
    "2. Integration with CUDA Graphs,\n",
    "3. FP8 scaling factors calibration,\n",
    "4. Keeping model parameters in FP8 precision.\n",
    "\n",
    "It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n",
    "\n",
    "1. Longer context lengths (with paged KV cache) \n",
    "2. Using less memory during generation (by storing weights in FP8 precision using `quantized_model_init`)\n",
    "\n",
    "Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
